package com.icesoft.system.safe.utils;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.interfaces.RSAKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.util.Base64;

/**
 * RSA helpers that encrypt/decrypt arbitrary-length data by splitting it into blocks and running one
 * {@link Cipher#doFinal(byte[], int, int)} call per block, plus String/Base64 convenience wrappers.
 */
public class RsaUtil {

	/** Charset used to turn plaintext strings into bytes and decrypted bytes back into strings. */
	private static final Charset DEFAULT_ENCODING = StandardCharsets.UTF_8;
	/**
	 * Maximum number of plaintext bytes passed to a single RSA encryption call.
	 * NOTE(review): 117 = 128 - 11, which matches a 1024-bit modulus with PKCS#1 v1.5 padding. This assumes
	 * {@code CipherUtils} builds a PKCS#1 v1.5 cipher for "RSA" — confirm. Keys smaller than that cannot
	 * encrypt a 117-byte block.
	 */
	private static final int MAX_ENCRYPT_BLOCK = 117;
	/** Transformation name handed to {@code CipherUtils.getCipher}. */
	private static final String RSA_TRANSFORMATION = "RSA";

	/**
	 * Encrypts a string with a public key given in encoded-string form (public-key encryption).
	 *
	 * @param publicKey encoded public key, in whatever format {@code KeyPairUtils.getPublicKey} accepts
	 * @param string    plaintext; encoded as UTF-8 before encryption
	 * @return standard Base64 encoding of the concatenated ciphertext blocks
	 * @throws IllegalArgumentException if the key is missing or encryption fails
	 */
	public static String encryptWithBase64(String publicKey, String string) throws IllegalArgumentException {
		RSAPublicKey rsaPublicKey = (RSAPublicKey) KeyPairUtils.getPublicKey(publicKey);
		return encryptWithBase64(rsaPublicKey, string);
	}

	/**
	 * Encrypts a string with an RSA public key (public-key encryption).
	 *
	 * @param publicKey RSA public key
	 * @param string    plaintext; encoded as UTF-8 before encryption
	 * @return standard Base64 encoding of the concatenated ciphertext blocks
	 * @throws IllegalArgumentException if the key is {@code null} or encryption fails
	 */
	public static String encryptWithBase64(RSAPublicKey publicKey, String string) throws IllegalArgumentException {
		byte[] binaryData = encrypt(publicKey, string.getBytes(DEFAULT_ENCODING));
		return Base64.getEncoder().encodeToString(binaryData);
	}

	/**
	 * Decrypts Base64-encoded ciphertext with an RSA private key (private-key decryption).
	 *
	 * @param privateKey   RSA private key
	 * @param base64String standard Base64 encoding of the ciphertext blocks
	 * @return the decrypted bytes decoded as UTF-8
	 * @throws IllegalArgumentException if the key is {@code null}, the input is not valid Base64, or decryption fails
	 */
	public static String decryptWithBase64(RSAPrivateKey privateKey, String base64String)
			throws IllegalArgumentException {
		byte[] binaryData = decrypt(privateKey, Base64.getDecoder().decode(base64String.getBytes(DEFAULT_ENCODING)));
		return new String(binaryData, DEFAULT_ENCODING);
	}

	/**
	 * Encrypts a string with an RSA private key (private-key encryption).
	 *
	 * @param privateKey RSA private key
	 * @param string     plaintext; encoded as UTF-8 before encryption
	 * @return standard Base64 encoding of the concatenated ciphertext blocks
	 * @throws IllegalArgumentException if the key is {@code null} or encryption fails
	 */
	public static String encryptWithBase64(RSAPrivateKey privateKey, String string) throws IllegalArgumentException {
		byte[] binaryData = encrypt(privateKey, string.getBytes(DEFAULT_ENCODING));
		return Base64.getEncoder().encodeToString(binaryData);
	}

	/**
	 * Decrypts Base64-encoded ciphertext with an RSA public key (public-key decryption).
	 *
	 * @param publicKey    RSA public key
	 * @param base64String standard Base64 encoding of the ciphertext blocks
	 * @return the decrypted bytes decoded as UTF-8
	 * @throws IllegalArgumentException if the key is {@code null}, the input is not valid Base64, or decryption fails
	 */
	public static String decryptWithBase64(RSAPublicKey publicKey, String base64String)
			throws IllegalArgumentException {
		byte[] binaryData = decrypt(publicKey, Base64.getDecoder().decode(base64String.getBytes(DEFAULT_ENCODING)));
		return new String(binaryData, DEFAULT_ENCODING);
	}

	/**
	 * Encrypts raw bytes with an RSA public key, {@value #MAX_ENCRYPT_BLOCK} plaintext bytes per block.
	 *
	 * @throws IllegalArgumentException if {@code publicKey} is {@code null} or encryption fails
	 */
	public static byte[] encrypt(RSAPublicKey publicKey, byte[] processData) throws IllegalArgumentException {
		if (publicKey == null) {
			throw new IllegalArgumentException("加密公钥为空, 请设置");
		}
		return encryptByKey(publicKey, processData);
	}

	/**
	 * Encrypts raw bytes with an RSA private key, {@value #MAX_ENCRYPT_BLOCK} plaintext bytes per block.
	 *
	 * @throws IllegalArgumentException if {@code privateKey} is {@code null} or encryption fails
	 */
	public static byte[] encrypt(RSAPrivateKey privateKey, byte[] processData) throws IllegalArgumentException {
		if (privateKey == null) {
			// Previously reported "public key" here (copy-paste error); this overload takes a private key.
			throw new IllegalArgumentException("加密私钥为空, 请设置");
		}
		return encryptByKey(privateKey, processData);
	}

	/**
	 * Decrypts raw ciphertext with an RSA private key, one modulus-sized block at a time.
	 *
	 * @throws IllegalArgumentException if {@code privateKey} is {@code null} or decryption fails
	 */
	public static byte[] decrypt(RSAPrivateKey privateKey, byte[] cipherData) throws IllegalArgumentException {
		if (privateKey == null) {
			throw new IllegalArgumentException("解密私钥为空, 请设置");
		}
		return decrypt(privateKey, privateKey, cipherData);
	}

	/**
	 * Decrypts raw ciphertext with an RSA public key, one modulus-sized block at a time.
	 *
	 * @throws IllegalArgumentException if {@code publicKey} is {@code null} or decryption fails
	 */
	public static byte[] decrypt(RSAPublicKey publicKey, byte[] cipherData) throws IllegalArgumentException {
		if (publicKey == null) {
			throw new IllegalArgumentException("解密公钥为空, 请设置");
		}
		return decrypt(publicKey, publicKey, cipherData);
	}

	/** Builds an encrypt-mode cipher for {@code key} and encrypts {@code processData} block by block. */
	private static byte[] encryptByKey(Key key, byte[] processData) throws IllegalArgumentException {
		Cipher cipher = getRsaCipher(key, Cipher.ENCRYPT_MODE);
		return encrypt(cipher, processData);
	}

	/**
	 * Builds a decrypt-mode cipher for {@code key} and decrypts {@code cipherData} in blocks whose size
	 * equals the byte length of {@code rsaKey}'s modulus.
	 */
	private static byte[] decrypt(Key key, RSAKey rsaKey, byte[] cipherData) throws IllegalArgumentException {
		Cipher cipher = getRsaCipher(key, Cipher.DECRYPT_MODE);
		// Every RSA ciphertext block is as long as the modulus in bytes, i.e. ceil(bitLength / 8).
		// Rounding up keeps moduli whose bit length is not a multiple of 8 (e.g. 2047 bits) correct.
		int chunkLen = (rsaKey.getModulus().bitLength() + 7) / 8;
		return decrypt(cipher, chunkLen, cipherData);
	}

	/**
	 * Returns a cipher for the "RSA" transformation obtained from {@code CipherUtils.getCipher}.
	 * NOTE(review): initialization with {@code key}/{@code mode} and the provider/padding used are
	 * decided by {@code CipherUtils}, which is not visible here.
	 *
	 * @param key  RSA key to use
	 * @param mode a {@link Cipher} mode constant such as {@link Cipher#ENCRYPT_MODE}
	 */
	public static Cipher getRsaCipher(Key key, int mode) {
		return CipherUtils.getCipher(RSA_TRANSFORMATION, key, mode);
	}

	/**
	 * Encrypts plaintext in blocks of at most {@value #MAX_ENCRYPT_BLOCK} bytes.
	 *
	 * @param cipher      cipher already set up for encryption
	 * @param processData plaintext bytes
	 * @return the concatenated ciphertext blocks
	 * @throws IllegalArgumentException wrapping the underlying cipher failure
	 */
	private static byte[] encrypt(Cipher cipher, byte[] processData) {
		try {
			return doFinalInBlocks(cipher, processData, MAX_ENCRYPT_BLOCK);
		} catch (IllegalBlockSizeException e) {
			// Raised while encrypting, so it concerns the plaintext block size, not the ciphertext.
			throw new IllegalArgumentException("明文长度非法", e);
		} catch (BadPaddingException e) {
			throw new IllegalArgumentException("密钥错误", e);
		}
	}

	/**
	 * Calls {@code cipher.doFinal} on consecutive slices of {@code data} of at most {@code blockSize} bytes
	 * and concatenates the outputs. Empty input yields an empty array.
	 *
	 * @param cipher    initialized cipher
	 * @param data      input bytes
	 * @param blockSize maximum slice length; must be positive
	 * @return concatenation of every slice's {@code doFinal} output
	 */
	private static byte[] doFinalInBlocks(Cipher cipher, byte[] data, int blockSize)
			throws IllegalBlockSizeException, BadPaddingException {
		// ByteArrayOutputStream holds no system resources and its close() has no effect, so it is not closed.
		ByteArrayOutputStream out = new ByteArrayOutputStream();
		for (int offset = 0; offset < data.length; offset += blockSize) {
			int length = Math.min(blockSize, data.length - offset);
			byte[] chunk = cipher.doFinal(data, offset, length);
			out.write(chunk, 0, chunk.length);
		}
		return out.toByteArray();
	}

	/**
	 * Decrypts ciphertext in blocks of {@code chunkLen} bytes.
	 *
	 * @param cipher     cipher already set up for decryption
	 * @param chunkLen   ciphertext block length in bytes (the key's modulus length)
	 * @param cipherData ciphertext bytes
	 * @return the decrypted plaintext bytes
	 * @throws IllegalArgumentException wrapping the underlying cipher failure
	 */
	private static byte[] decrypt(Cipher cipher, int chunkLen, byte[] cipherData) {
		try {
			return doFinalInBlocks(cipher, cipherData, chunkLen);
		} catch (IllegalBlockSizeException e) {
			throw new IllegalArgumentException("密文错误", e);
		} catch (BadPaddingException e) {
			throw new IllegalArgumentException("密钥错误", e);
		}
	}
}